{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# **MIT 6.5940 EfficientML.ai Fall 2023: Lab 0 PyTorch Tutorial**"
   ],
   "metadata": {
    "id": "PFbUpH8elWQ7"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "In this tutorial, we will explore how to train a neural network with PyTorch."
   ],
   "metadata": {
    "id": "aoNr0MWd5e5m"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Setup"
   ],
   "metadata": {
    "id": "yoBtxdvR5lwM"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "We will first install a few packages that will be used in this tutorial:"
   ],
   "metadata": {
    "id": "0oLGv2RjLYh2"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "!pip install torchprofile 1>/dev/null"
   ],
   "metadata": {
    "id": "3r7Sl2cG7nZF"
   },
   "execution_count": 1,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "We will then import a few libraries:"
   ],
   "metadata": {
    "id": "LYgp0au_LeAd"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "import random\n",
    "from collections import OrderedDict, defaultdict\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from matplotlib import pyplot as plt\n",
    "from torch import nn\n",
    "from torch.optim import SGD, Optimizer\n",
    "from torch.optim.lr_scheduler import LambdaLR\n",
    "from torch.utils.data import DataLoader\n",
    "from torchprofile import profile_macs\n",
    "from torchvision.datasets import CIFAR10\n",
    "from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip, ToTensor\n",
    "from tqdm.auto import tqdm"
   ],
   "metadata": {
    "id": "I3uAhaCSlFrK"
   },
   "execution_count": 3,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "To ensure reproducibility, we fix the seeds of the random number generators:"
   ],
   "metadata": {
    "id": "u1Yx0rDUK5fx"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "random.seed(0)\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)"
   ],
   "metadata": {
    "id": "j_l1wEdeHOlu"
   },
   "execution_count": 4,
   "outputs": [
    {
     "data": {
      "text/plain": "<torch._C.Generator at 0x11ab391d0>"
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Data"
   ],
   "metadata": {
    "id": "u7Y0sLyajGAu"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "In this tutorial, we will use CIFAR-10 as our target dataset. This dataset contains images from 10 classes, where each image is of\n",
    "size 3x32x32, i.e. 3-channel color images of 32x32 pixels in size."
   ],
   "metadata": {
    "id": "VAbL_li0KPsz"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "transforms = {\n",
    "  \"train\": Compose([\n",
    "    RandomCrop(32, padding=4),\n",
    "    RandomHorizontalFlip(),\n",
    "    ToTensor(),\n",
    "  ]),\n",
    "  \"test\": ToTensor(),\n",
    "}\n",
    "\n",
    "dataset = {}\n",
    "for split in [\"train\", \"test\"]:\n",
    "  dataset[split] = CIFAR10(\n",
    "    root=\"data/cifar10\",\n",
    "    train=(split == \"train\"),\n",
    "    download=True,\n",
    "    transform=transforms[split],\n",
    "  )"
   ],
   "metadata": {
    "id": "Pqhy8EJSjJfp"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "We can visualize a few images in the dataset and their corresponding class labels:"
   ],
   "metadata": {
    "id": "ft9wv-tIMUgl"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "samples = [[] for _ in range(10)]\n",
    "for image, label in dataset[\"test\"]:\n",
    "  if len(samples[label]) < 4:\n",
    "    samples[label].append(image)\n",
    "\n",
    "plt.figure(figsize=(20, 9))\n",
    "for index in range(40):\n",
    "  label = index % 10\n",
    "  image = samples[label][index // 10]\n",
    "\n",
    "  # Convert from CHW to HWC for visualization\n",
    "  image = image.permute(1, 2, 0)\n",
    "\n",
    "  # Convert from class index to class name\n",
    "  label = dataset[\"test\"].classes[label]\n",
    "\n",
    "  # Visualize the image\n",
    "  plt.subplot(4, 10, index + 1)\n",
    "  plt.imshow(image)\n",
    "  plt.title(label)\n",
    "  plt.axis(\"off\")\n",
    "plt.show()"
   ],
   "metadata": {
    "id": "ofwgqYb2qsd2"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "To train a neural network, we need to feed it data in batches. We create data loaders with a batch size of 512:"
   ],
   "metadata": {
    "id": "jkigVqADNeIN"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "dataflow = {}\n",
    "for split in ['train', 'test']:\n",
    "  dataflow[split] = DataLoader(\n",
    "    dataset[split],\n",
    "    batch_size=512,\n",
    "    shuffle=(split == 'train'),\n",
    "    num_workers=0,\n",
    "    pin_memory=True,\n",
    "  )"
   ],
   "metadata": {
    "id": "4axnQCtnks_s"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "We can print the data type and shape from the training data loader:"
   ],
   "metadata": {
    "id": "_5G1Lf6hOLGT"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "for inputs, targets in dataflow[\"train\"]:\n",
    "  print(\"[inputs] dtype: {}, shape: {}\".format(inputs.dtype, inputs.shape))\n",
    "  print(\"[targets] dtype: {}, shape: {}\".format(targets.dtype, targets.shape))\n",
    "  break"
   ],
   "metadata": {
    "id": "ReP2g9pD6ppI"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Model"
   ],
   "metadata": {
    "id": "sPAEVnixjwb7"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "In this tutorial, we will use a variant of [VGG-11](https://arxiv.org/abs/1409.1556) (with fewer downsamples and a smaller classifier) as our model."
   ],
   "metadata": {
    "id": "rFr1Js3-e3rJ"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "class VGG(nn.Module):\n",
    "  ARCH = [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']\n",
    "\n",
    "  def __init__(self) -> None:\n",
    "    super().__init__()\n",
    "\n",
    "    layers = []\n",
    "    counts = defaultdict(int)\n",
    "\n",
    "    def add(name: str, layer: nn.Module) -> None:\n",
    "      layers.append((f\"{name}{counts[name]}\", layer))\n",
    "      counts[name] += 1\n",
    "\n",
    "    in_channels = 3\n",
    "    for x in self.ARCH:\n",
    "      if x != 'M':\n",
    "        # conv-bn-relu\n",
    "        add(\"conv\", nn.Conv2d(in_channels, x, 3, padding=1, bias=False))\n",
    "        add(\"bn\", nn.BatchNorm2d(x))\n",
    "        add(\"relu\", nn.ReLU(True))\n",
    "        in_channels = x\n",
    "      else:\n",
    "        # maxpool\n",
    "        add(\"pool\", nn.MaxPool2d(2))\n",
    "\n",
    "    self.backbone = nn.Sequential(OrderedDict(layers))\n",
    "    self.classifier = nn.Linear(512, 10)\n",
    "\n",
    "  def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "    # backbone: [N, 3, 32, 32] => [N, 512, 2, 2]\n",
    "    x = self.backbone(x)\n",
    "\n",
    "    # avgpool: [N, 512, 2, 2] => [N, 512]\n",
    "    x = x.mean([2, 3])\n",
    "\n",
    "    # classifier: [N, 512] => [N, 10]\n",
    "    x = self.classifier(x)\n",
    "    return x\n",
    "\n",
    "model = VGG().cuda()"
   ],
   "metadata": {
    "id": "SNLdS_UQjyBf"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "Its backbone is composed of eight `conv-bn-relu` blocks interleaved with four `maxpool` layers, which together downsample the feature map by a factor of 2^4 = 16 (from 32x32 to 2x2):"
   ],
   "metadata": {
    "id": "x19LMKQYw0DI"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "print(model.backbone)"
   ],
   "metadata": {
    "id": "BUWmYS2owzCe"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "After the feature map is pooled, its classifier predicts the final output with a linear layer:"
   ],
   "metadata": {
    "id": "mApr58LmyDqr"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "print(model.classifier)"
   ],
   "metadata": {
    "id": "S1GoSsh_yJgD"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "As this course focuses on efficiency, we will then inspect its model size and (theoretical) computation cost.\n"
   ],
   "metadata": {
    "id": "F_RcCWoQ8Kp1"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "* The model size can be estimated by the number of trainable parameters:"
   ],
   "metadata": {
    "id": "Zd4Xu-vMyz39"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "num_params = 0\n",
    "for param in model.parameters():\n",
    "  if param.requires_grad:\n",
    "    num_params += param.numel()\n",
    "print(\"#Params:\", num_params)"
   ],
   "metadata": {
    "id": "4gTfqC0B7Uzi"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "* The computation cost can be estimated by the number of [multiply–accumulate operations (MACs)](https://en.wikipedia.org/wiki/Multiply–accumulate_operation) using [TorchProfile](https://github.com/zhijian-liu/torchprofile):"
   ],
   "metadata": {
    "id": "uAZoIKIbzLa4"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "num_macs = profile_macs(model, torch.zeros(1, 3, 32, 32).cuda())\n",
    "print(\"#MACs:\", num_macs)"
   ],
   "metadata": {
    "id": "OKVmyWCN7qpp"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "This model has 9.2M parameters and requires 606M MACs for inference. We will work together in the next few labs to improve its efficiency."
   ],
   "metadata": {
    "id": "OYkqpfejzxwq"
   }
  },
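  {
   "cell_type": "markdown",
   "source": [
    "As a sanity check, these totals can be approximately reproduced by hand from the layer configuration. The sketch below is not part of the original lab: it assumes 3x3 convolutions without bias, two trainable parameters per BatchNorm channel, a 32x32 input, and that only convolution and linear layers contribute MACs (a profiler may also count small contributions from other layers). The helper name `count_params_and_macs` is made up for illustration."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "# Hand count of parameters and MACs for the VGG variant defined above.\n",
    "# Assumptions: 3x3 convs without bias, BatchNorm with a weight and a bias per\n",
    "# channel, a 32x32 input, and MACs counted only for conv and linear layers.\n",
    "def count_params_and_macs(arch, in_channels=3, resolution=32, num_classes=10):\n",
    "  params, macs = 0, 0\n",
    "  for x in arch:\n",
    "    if x == 'M':\n",
    "      # Each 2x2 max pooling halves the spatial resolution\n",
    "      resolution //= 2\n",
    "      continue\n",
    "    conv_weights = in_channels * x * 3 * 3\n",
    "    params += conv_weights + 2 * x          # conv weights + BN weight and bias\n",
    "    macs += conv_weights * resolution ** 2  # one MAC per weight per output pixel\n",
    "    in_channels = x\n",
    "  # Final linear classifier: weights plus biases\n",
    "  params += in_channels * num_classes + num_classes\n",
    "  macs += in_channels * num_classes\n",
    "  return params, macs\n",
    "\n",
    "hand_params, hand_macs = count_params_and_macs(\n",
    "  [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']\n",
    ")\n",
    "print(f\"#Params: {hand_params / 1e6:.1f}M, #MACs: {hand_macs / 1e6:.0f}M\")"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },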
  {
   "cell_type": "markdown",
   "source": [
    "### Optimization"
   ],
   "metadata": {
    "id": "gjDsY9_KkIjZ"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "As we are working on a classification problem, we will apply [cross entropy](https://en.wikipedia.org/wiki/Cross_entropy) as our loss function to optimize the model:"
   ],
   "metadata": {
    "id": "oRg_5KeKLHPj"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "criterion = nn.CrossEntropyLoss()"
   ],
   "metadata": {
    "id": "-K0DEhGKkKfF"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "Optimization will be carried out using [stochastic gradient descent (SGD)](https://en.wikipedia.org/wiki/Stochastic_gradient_descent) with [momentum](https://en.wikipedia.org/wiki/Stochastic_gradient_descent#Momentum):"
   ],
   "metadata": {
    "id": "3H8YniYeLIdg"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "optimizer = SGD(\n",
    "  model.parameters(),\n",
    "  lr=0.4,\n",
    "  momentum=0.9,\n",
    "  weight_decay=5e-4,\n",
    ")"
   ],
   "metadata": {
    "id": "HXANib83LATH"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "The learning rate will be modulated using the following scheduler (which is adapted from [this blog series](https://myrtle.ai/learn/how-to-train-your-resnet/)):"
   ],
   "metadata": {
    "id": "v9X8SiWYLJw2"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "num_epochs = 20\n",
    "steps_per_epoch = len(dataflow[\"train\"])\n",
    "\n",
    "# Define the piecewise linear scheduler\n",
    "lr_lambda = lambda step: np.interp(\n",
    "  [step / steps_per_epoch],\n",
    "  [0, num_epochs * 0.3, num_epochs],\n",
    "  [0, 1, 0]\n",
    ")[0]\n",
    "\n",
    "# Visualize the learning rate schedule\n",
    "steps = np.arange(steps_per_epoch * num_epochs)\n",
    "plt.plot(steps, [lr_lambda(step) * 0.4 for step in steps])\n",
    "plt.xlabel(\"Number of Steps\")\n",
    "plt.ylabel(\"Learning Rate\")\n",
    "plt.grid(True)\n",
    "plt.show()\n",
    "\n",
    "scheduler = LambdaLR(optimizer, lr_lambda)"
   ],
   "metadata": {
    "id": "8mJU5aw8KrVX"
   },
   "execution_count": null,
   "outputs": []
  },
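  {
   "cell_type": "markdown",
   "source": [
    "To make the schedule concrete, the standalone sketch below evaluates the same piecewise-linear interpolation at a few epoch boundaries. It restates the numbers used above (20 epochs, a peak after 30% of training, a base learning rate of 0.4); the `demo_` names are introduced here only so that no notebook variable is overwritten."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "import numpy as np\n",
    "\n",
    "demo_num_epochs = 20\n",
    "demo_base_lr = 0.4\n",
    "\n",
    "for demo_epoch in [0, 3, 6, 13, 20]:\n",
    "  # The multiplier rises linearly from 0 to 1 over the first 30% of training,\n",
    "  # then decays linearly back to 0 at the final epoch\n",
    "  demo_factor = np.interp(\n",
    "    demo_epoch,\n",
    "    [0, demo_num_epochs * 0.3, demo_num_epochs],\n",
    "    [0, 1, 0]\n",
    "  )\n",
    "  print(f\"epoch {demo_epoch:2d}: lr = {demo_base_lr * demo_factor:.3f}\")"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "Under these settings the learning rate warms up to 0.4 by epoch 6 and then decays linearly to 0 by epoch 20."
   ],
   "metadata": {}
  },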
  {
   "cell_type": "markdown",
   "source": [
    "### Training"
   ],
   "metadata": {
    "id": "i2UFRbRYly50"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "We first define the training function that optimizes the model for one epoch (*i.e.*, a pass over the training set):"
   ],
   "metadata": {
    "id": "IpHZJpjR7Wy3"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "def train(\n",
    "  model: nn.Module,\n",
    "  dataflow: DataLoader,\n",
    "  criterion: nn.Module,\n",
    "  optimizer: Optimizer,\n",
    "  scheduler: LambdaLR,\n",
    ") -> None:\n",
    "  model.train()\n",
    "\n",
    "  for inputs, targets in tqdm(dataflow, desc='train', leave=False):\n",
    "    # Move the data from CPU to GPU\n",
    "    inputs = inputs.cuda()\n",
    "    targets = targets.cuda()\n",
    "\n",
    "    # Reset the gradients (from the last iteration)\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    # Forward inference\n",
    "    outputs = model(inputs)\n",
    "    loss = criterion(outputs, targets)\n",
    "\n",
    "    # Backward propagation\n",
    "    loss.backward()\n",
    "\n",
    "    # Update optimizer and LR scheduler\n",
    "    optimizer.step()\n",
    "    scheduler.step()"
   ],
   "metadata": {
    "id": "79GKx_oVl09b"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "We then define the evaluation function that calculates the metric (*i.e.*, accuracy in our case) on the test set:"
   ],
   "metadata": {
    "id": "SwAbMdUq7YrE"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "@torch.inference_mode()\n",
    "def evaluate(\n",
    "  model: nn.Module,\n",
    "  dataflow: DataLoader\n",
    ") -> float:\n",
    "  model.eval()\n",
    "\n",
    "  num_samples = 0\n",
    "  num_correct = 0\n",
    "\n",
    "  for inputs, targets in tqdm(dataflow, desc=\"eval\", leave=False):\n",
    "    # Move the data from CPU to GPU\n",
    "    inputs = inputs.cuda()\n",
    "    targets = targets.cuda()\n",
    "\n",
    "    # Inference\n",
    "    outputs = model(inputs)\n",
    "\n",
    "    # Convert logits to class indices\n",
    "    outputs = outputs.argmax(dim=1)\n",
    "\n",
    "    # Update metrics\n",
    "    num_samples += targets.size(0)\n",
    "    num_correct += (outputs == targets).sum()\n",
    "\n",
    "  return (num_correct / num_samples * 100).item()"
   ],
   "metadata": {
    "id": "44_AriMP4G_A"
   },
   "execution_count": null,
   "outputs": []
  },
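  {
   "cell_type": "markdown",
   "source": [
    "The accuracy computation inside `evaluate` can be checked on a toy batch. The sketch below uses made-up logits for four samples and three classes (not CIFAR-10 data) and runs on the CPU; the `demo_` names are illustrative only."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "import torch\n",
    "\n",
    "# Made-up logits for 4 samples and 3 classes, with their ground-truth labels\n",
    "demo_logits = torch.tensor([\n",
    "  [2.0, 0.1, 0.3],\n",
    "  [0.2, 1.5, 0.1],\n",
    "  [0.9, 0.8, 0.7],\n",
    "  [0.1, 0.2, 3.0],\n",
    "])\n",
    "demo_targets = torch.tensor([0, 1, 2, 2])\n",
    "\n",
    "# argmax over the class dimension picks the highest-scoring class per sample\n",
    "demo_preds = demo_logits.argmax(dim=1)\n",
    "\n",
    "# Same metric as in evaluate: percentage of predictions that match the targets\n",
    "demo_correct = (demo_preds == demo_targets).sum().item()\n",
    "demo_accuracy = demo_correct / demo_targets.size(0) * 100\n",
    "print(demo_preds.tolist(), demo_accuracy)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },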
  {
   "cell_type": "markdown",
   "source": [
    "With training and evaluation functions, we can finally start training the model! This will take around 10 minutes."
   ],
   "metadata": {
    "id": "6XSU9oFD7aXs"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "for epoch_num in tqdm(range(1, num_epochs + 1)):\n",
    "  train(model, dataflow[\"train\"], criterion, optimizer, scheduler)\n",
    "  metric = evaluate(model, dataflow[\"test\"])\n",
    "  print(f\"epoch {epoch_num}:\", metric)"
   ],
   "metadata": {
    "id": "4iWaYpw_4E8f"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "If everything goes well, your trained model should reach an accuracy above 92.5%!"
   ],
   "metadata": {
    "id": "96XTuBat-e-d"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Visualization"
   ],
   "metadata": {
    "id": "ck6iME0rLjuk"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "We can visualize the model's predictions on a few test images to see how it performs in practice:"
   ],
   "metadata": {
    "id": "-8rVV1SOSUsR"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "plt.figure(figsize=(20, 10))\n",
    "\n",
    "# Switch to evaluation mode once, before the loop\n",
    "model.eval()\n",
    "for index in range(40):\n",
    "  image, label = dataset[\"test\"][index]\n",
    "\n",
    "  # Model inference\n",
    "  with torch.inference_mode():\n",
    "    pred = model(image.unsqueeze(dim=0).cuda())\n",
    "    # Convert the single-element prediction tensor to a Python int\n",
    "    pred = pred.argmax(dim=1).item()\n",
    "\n",
    "  # Convert from CHW to HWC for visualization\n",
    "  image = image.permute(1, 2, 0)\n",
    "\n",
    "  # Convert from class indices to class names\n",
    "  pred = dataset[\"test\"].classes[pred]\n",
    "  label = dataset[\"test\"].classes[label]\n",
    "\n",
    "  # Visualize the image\n",
    "  plt.subplot(4, 10, index + 1)\n",
    "  plt.imshow(image)\n",
    "  plt.title(f\"pred: {pred}\" + \"\\n\" + f\"label: {label}\")\n",
    "  plt.axis(\"off\")\n",
    "plt.show()"
   ],
   "metadata": {
    "id": "inwZEfX3Mo6A"
   },
   "execution_count": null,
   "outputs": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
